from model.BaseModel import BaseModel
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tushare as ts
from sklearn.preprocessing import MinMaxScaler
import warnings

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM, SimpleRNN
from keras.layers import Dropout
from keras.layers import Conv1D, GlobalMaxPooling1D, MaxPooling1D, Flatten
from io import BytesIO
from PIL import Image
import base64
from sklearn.metrics import mean_squared_error  # 均方误差
from sklearn.metrics import mean_absolute_error  # 平方绝对误差


class CNNModel(BaseModel):
    """1-D convolutional regressor for daily stock data, driven over a websocket.

    Typical call order is ``setWS`` / ``setCode`` / ``setTime`` followed by
    ``view``, ``train`` and ``draw``.  Every coroutine reports back through
    ``self.websocket.send(self.result(...))``; ``result`` is not defined in
    this class and is presumably inherited from ``BaseModel``.

    NOTE(review): ``BaseModel.__init__`` is not invoked by ``__init__`` (same
    as before this change) -- confirm whether the base class needs it.
    """

    # Number of consecutive rows that form one input sample.
    WINDOW_SIZE = 60
    # Number of leading samples used for training when enough data exists.
    TRAIN_SAMPLES = 1600
    # Position, inside ``all_data``, of the column used as regression target
    # (presumably the closing price -- verify against tushare's column order).
    LABEL_COLUMN = 1

    def __init__(self):
        self.x_train = None
        self.x_test = None
        self.y_test = None
        self.y_train = None
        self.websocket = None
        self.history = None
        self.imgUrls = []
        self.all_data = None
        self.regressor = None
        self.code = None
        self.startTime = None
        self.endTime = None

    def setWS(self, ws):
        """Store the websocket used to push results to the client."""
        self.websocket = ws

    def setCode(self, code):
        """Store the stock code, coerced to ``str``."""
        self.code = str(code)

    def setTime(self, start, end):
        """Store the start/end bounds passed to ``ts.get_k_data``."""
        self.startTime = start
        self.endTime = end

    async def view(self):
        """Fetch the configured K-line data and send it to the client.

        Each row is sent as a dict with the keys ``date``, ``open``,
        ``close``, ``high``, ``low`` and ``volume``.
        """
        print(self.code)
        print(self.startTime)
        print(self.endTime)
        data = ts.get_k_data(self.code, start=self.startTime, end=self.endTime)
        columns = ('date', 'open', 'close', 'high', 'low', 'volume')
        # Rows are no longer echoed to stdout one by one: that was debug
        # output costing one console write per row inside an async handler.
        rows = [{name: row[name] for name in columns} for _, row in data.iterrows()]

        await self.websocket.send(self.result("data", "cnn-view-data", "200", "模型训练开始", rows))

    async def train(self):
        """Fetch data, build sliding-window samples and fit the Conv1D model.

        Populates ``all_data``, ``x_train``/``y_train``, ``x_test``/``y_test``,
        ``regressor`` and ``history``.  A start and an end message are sent
        over the websocket.

        Raises:
            ValueError: if the fetched data is too short to yield at least
                one training and one test window.

        NOTE(review): ``fit`` is a synchronous call, so the event loop is
        blocked for the whole training run (unchanged from the original).
        """
        print(self.code)
        print(self.startTime)
        print(self.endTime)
        await self.websocket.send(self.result("message", "cnn-train", "200", "模型训练开始"))

        # Historical K-line data for the configured stock code via tushare.
        data = ts.get_k_data(self.code, start=self.startTime, end=self.endTime)
        window = self.WINDOW_SIZE
        # Each sample needs `window` rows of history plus the row being
        # predicted; two samples are the minimum for a train/test split.
        if data is None or len(data) < window + 2:
            raise ValueError(
                "not enough rows to train: need at least %d, got %d"
                % (window + 2, 0 if data is None else len(data))
            )

        # Columns 1..5 of the tushare frame (presumably open, close, high,
        # low, volume -- the date column at position 0 is dropped).
        self.all_data = data.iloc[:, 1:6]
        print(self.all_data.head())
        scaler = MinMaxScaler(feature_range=(0, 1))
        scaled = scaler.fit_transform(self.all_data)

        # Sample i is the `window` rows preceding row i; its label is the
        # target column of row i itself.
        features = np.array([scaled[i - window:i] for i in range(window, len(scaled))])
        labels = scaled[window:, self.LABEL_COLUMN]

        # Keep the historical fixed split when there is enough data;
        # otherwise fall back to an 80/20 split so the test set is never
        # empty (the fixed split left it empty for short date ranges).
        n_samples = len(features)
        if n_samples > self.TRAIN_SAMPLES:
            split = self.TRAIN_SAMPLES
        else:
            split = max(1, int(n_samples * 0.8))
        self.x_train, self.x_test = features[:split], features[split:]
        self.y_train, self.y_test = labels[:split], labels[split:]

        warnings.filterwarnings("ignore")
        filters = 250
        kernel_size = 3
        # Build the convolutional network.
        self.regressor = Sequential()
        # Convolution over the time axis, followed by dropout regularisation.
        self.regressor.add(
            Conv1D(filters, kernel_size, padding='same', activation='relu',
                   input_shape=(self.x_train.shape[1], self.x_train.shape[2])))
        self.regressor.add(Dropout(0.2))
        # Keep the larger of every two steps, halving the sequence length.
        self.regressor.add(MaxPooling1D(2))
        # Collapse the (steps, filters) output into one vector.
        self.regressor.add(Flatten())
        # Single-unit output layer for the regression target.
        self.regressor.add(Dense(units=1))
        self.regressor.compile(optimizer='adam', loss='mean_squared_error')
        self.history = self.regressor.fit(
            self.x_train, self.y_train, epochs=5, batch_size=32,
            validation_data=(self.x_test, self.y_test))
        await self.websocket.send(self.result("message", "cnn-train", "200", "模型训练结束"))

    @staticmethod
    def _figure_to_base64(fig):
        """Render *fig* as PNG, close it, and return the base64 text."""
        buffer = BytesIO()
        fig.savefig(buffer, format='png')
        # Closing releases the figure so charts do not accumulate in memory.
        plt.close(fig)
        return base64.b64encode(buffer.getvalue()).decode()

    def _price_chart(self, real, predicted, title):
        """Plot real vs. predicted prices and return the chart as base64 PNG."""
        fig, ax = plt.subplots()
        ax.plot(real, color='red', label='Real Stock Price')
        ax.plot(predicted, color='blue', label='Predicted TAT Stock Price')
        ax.set_title(title)
        ax.set_xlabel('Time')
        ax.set_ylabel('Stock Price')
        # The legend must exist before the figure is rendered to PNG.
        ax.legend()
        return self._figure_to_base64(fig)

    async def draw(self):
        """Render the loss and prediction charts and send them with the test metrics.

        Three base64-encoded PNGs are stored in ``imgUrls`` (loss curve,
        train predictions, test predictions) and sent together with the MSE
        and MAE on the test split and the head of ``all_data``.

        Each chart is drawn on its own figure, which is closed after
        rendering.  ``plt.show()`` is no longer called: it blocks under GUI
        backends, and under non-interactive backends it left earlier curves
        on the shared figure.

        Raises:
            RuntimeError: if called before ``train`` has completed.
        """
        if self.history is None or self.regressor is None:
            raise RuntimeError("draw() requires a trained model; call train() first")

        self.imgUrls = []

        # Chart 1: training / validation loss per epoch.
        loss = self.history.history['loss']
        val_loss = self.history.history['val_loss']
        epochs = range(1, len(loss) + 1)
        fig, ax = plt.subplots()
        ax.set_title('Loss curve')
        ax.plot(epochs, loss, color='red', label='Training loss')
        ax.plot(epochs, val_loss, color='blue', label='Validation loss')
        ax.legend()
        self.imgUrls.append(self._figure_to_base64(fig))

        # A scaler fitted on the target column alone, so that predictions and
        # labels can be mapped back from [0, 1] to the original price scale.
        target = self.LABEL_COLUMN
        sc_one = MinMaxScaler(feature_range=(0, 1))
        sc_one.fit(self.all_data.iloc[:, target:target + 1])
        predicted_stock_train = sc_one.inverse_transform(self.regressor.predict(self.x_train))
        predicted_stock_test = sc_one.inverse_transform(self.regressor.predict(self.x_test))
        real_price_train = sc_one.inverse_transform(np.reshape(self.y_train, (-1, 1)))
        real_price_test = sc_one.inverse_transform(np.reshape(self.y_test, (-1, 1)))

        # Chart 2: train split; chart 3: test split.
        self.imgUrls.append(
            self._price_chart(real_price_train, predicted_stock_train, 'train Stock Price Prediction'))
        self.imgUrls.append(
            self._price_chart(real_price_test, predicted_stock_test, 'test Stock Price Prediction'))

        mse_score = mean_squared_error(real_price_test, predicted_stock_test)
        mae_score = mean_absolute_error(real_price_test, predicted_stock_test)
        print('测试集的均方误差是:', mse_score)
        print('测试集的平方绝对误差是:', mae_score)
        print(self.all_data.head())
        # Send the charts and test metrics back to the client.
        await self.websocket.send(self.result("data", "cnn-draw", "200", "success", {
            "imgUrls": self.imgUrls,
            "mse_score": mse_score,
            "mae_score": mae_score,
            "all_data": str(self.all_data.head())
        }))
